from torchair._ge_concrete_graph.ge_converter.converter_utils import *


@declare_supported(
    [
        Support([I32(2, 2, 2), I32(2, 3), I32(2, 3)], 1),
        Support([F32(2, 2, 2), F32(2, 3), F32(2, 3)], 1.),
        Support([F16(2, 2, 2), F16(2, 3), F16(2, 3)], 1.),
        Support([BF16(2, 2, 2), BF16(2, 3), BF16(2, 3)], 1.),
    ]
)
@register_fx_node_ge_converter(torch.ops.aten._foreach_minimum.Scalar)
def conveter_aten__foreach_minimum_scalar(
    self: List[Tensor],
    scalar: Union[Number, Tensor],
    meta_outputs: List[TensorSpec] = None,
):
    """Convert aten::_foreach_minimum.Scalar into a GE ForeachMinimumScalar node.

    Schema: aten::_foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]

    Args:
        self: The list of input tensors.
        scalar: The scalar operand. When ``self`` is non-empty it is promoted
            before being handed to the GE op; the target dtype is taken from
            the first tensor only.
        meta_outputs: Output specs supplied by the converter framework; not
            read by this converter.

    Returns:
        The result of ``ge.ForeachMinimumScalar``.
    """
    # An empty tensor list gives no dtype to promote to: forward the scalar as-is.
    if len(self) == 0:
        return ge.ForeachMinimumScalar(self, scalar)

    first_dtype = self[0].dtype
    # BF16 inputs get a DT_FLOAT scalar instead of a BF16 one.
    # NOTE(review): presumably the GE op expects a float scalar for BF16
    # tensors - confirm against the operator spec.
    if first_dtype == DataType.DT_BF16:
        promote_to = DataType.DT_FLOAT
    else:
        promote_to = first_dtype

    promoted_scalar = dtype_promote(scalar, target_dtype=promote_to)
    return ge.ForeachMinimumScalar(self, promoted_scalar)


@declare_supported(
    [
        Support([F32(2, 2, 2), F32(2, 3), F32(2, 3)], [1., 1., 1.]),
        Support([F16(2, 2, 2), F16(2, 3), F16(2, 3)], [1., 1., 1.]),
        Support([BF16(2, 2, 2), BF16(2, 3), BF16(2, 3)], [1., 1., 1.]),
        Support([I32(2, 2, 2), I32(2, 3)], [1, 1]),
    ]
)
@register_fx_node_ge_converter(torch.ops.aten._foreach_minimum.ScalarList)
def conveter_aten__foreach_minimum_scalarlist(
    self: List[Tensor],
    scalars: Union[List[Number], Tensor],
    meta_outputs: List[TensorSpec] = None,
):
    """Convert aten::_foreach_minimum.ScalarList into a GE ForeachMinimumScalarList node.

    Schema: aten::_foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]

    Args:
        self: The list of input tensors.
        scalars: The scalar operands. If the first entry is a Python ``int``
            the whole sequence is promoted to ``DT_INT64``; otherwise it is
            forwarded unchanged.
        meta_outputs: Output specs supplied by the converter framework; not
            read by this converter.

    Returns:
        The result of ``ge.ForeachMinimumScalarList``.
    """
    # Only the first element decides whether promotion happens.
    # NOTE(review): ``bool`` is a subclass of ``int`` and a mixed list such as
    # [1, 1.5] would also take the INT64 path - confirm callers never pass those.
    if len(scalars) == 0 or not isinstance(scalars[0], int):
        return ge.ForeachMinimumScalarList(self, scalars)

    int64_scalars = dtype_promote(scalars, target_dtype=DataType.DT_INT64)
    return ge.ForeachMinimumScalarList(self, int64_scalars)


@declare_supported(
    [
        Support([F32(2, 2, 2), F32(2, 3), F32(2, 3)], [F32(2, 2, 2), F32(2, 3), F32(2, 3)]),
        Support([F16(2, 2, 2), F16(2, 3), F16(2, 3)], [F16(2, 2, 2), F16(2, 3), F16(2, 3)]),
        Support([BF16(2, 2, 2), BF16(2, 3), BF16(2, 3)], [BF16(2, 2, 2), BF16(2, 3), BF16(2, 3)]),
    ]
)
@register_fx_node_ge_converter(torch.ops.aten._foreach_minimum.List)
def conveter_aten__foreach_minimum_list(
    self: List[Tensor],
    other: List[Tensor],
    meta_outputs: List[TensorSpec] = None,
):
    """Convert aten::_foreach_minimum.List into a GE ForeachMinimumList node.

    Schema: aten::_foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[]

    Args:
        self: The first list of input tensors.
        other: The second list of input tensors.
        meta_outputs: Output specs supplied by the converter framework; not
            read by this converter.

    Returns:
        The result of ``ge.ForeachMinimumList``.
    """
    # Both lists are forwarded untouched; no dtype promotion is done here.
    minimum_outputs = ge.ForeachMinimumList(self, other)
    return minimum_outputs
